{ "cells": [ { "cell_type": "markdown", "id": "e2b62a7c", "metadata": {}, "source": [ "# Control MXU Floating Point Precision\n", " **Author:** `Yaoshiang Ho`\n", "\n", " **Date created:** 2025/05/15\n", "\n", " **Last modified:** 2025/05/15\n", "\n", " In this tutorial, you will learn how to control the floating point\n", " precision of matrix multiplication (mat mul) operations\n", " when using certain accelerators with PyTorch/XLA, such as TPUs.\n", " You will also learn how to access torch's floating point info\n", " and how to visually inspect the floating point representation of numbers." ] }, { "cell_type": "markdown", "id": "66ff85ec", "metadata": {}, "source": [ "## Introduction\n", "\n", " Google TPUs are built with matrix multiplication optimized on silicon\n", " in a physical module called a Matrix Multiply Unit or MXU.\n", " To maintain speed, researchers identified an inexpensive tradeoff.\n", " [The research](https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus)\n", " showed that neural networks were able to train with less precision\n", " than FP32\n", " \\\"without having any noticeable impact on model accuracy\\\".\n", " The same was not true for range. Due to operations like norms,\n", " FP32's range was important to keep. The solution was bfloat16:\n", " the same range as FP32, with less precision.\n", "\n", " Nvidia V100 and newer GPUs also include specialized matrix multiplication\n", " units called TensorCores. These GPUs use a numerical format called\n", " TF32, which has the same range as FP32 and bfloat16, but an\n", " intermediate precision (10 bits of mantissa)\n", " because TF32 only has 19 total bits.\n", "\n", " Matrix multiplication operations performed on FP32 values will\n", " yield results in bfloat16 for TPUs and TF32 (with 19 bits)\n", " for Nvidia GPUs.\n", "\n", " ![bits layout](../_static/img/bit_layout.svg)" ] }, { "cell_type": "markdown", "id": "95e5e8fd", "metadata": {}, "source": [ "## Higher precision math on lower precision hardware\n", "\n", " Even with the 7 mantissa bits of bfloat16, it is possible to calculate\n", " math in higher precision. This is done by breaking up a number into\n", " its components. To build intuition, imagine an MXU\n", " that supports 2 digits in a base-10 (decimal) number system.\n", " The goal is to multiply numbers with\n", " 4 digits of precision, for example, $9.111$ and $9.222$.\n", " In infinite precision, the product is $84.021642$. Notice that\n", " two numbers with 4 digits of precision generates twice as many\n", " digits of precision in the result. But given the number format\n", " is 4 digits, the result will be rounded to $84.02$.\n", "\n", " The simplest approach is to round the numbers to $9.1$ and $9.2$,\n", " resulting in $83.72$. This is conceptually the \"default\" precision\n", " setting of PyTorch/XLA on TPU.\n", "\n", " The next approach is to break each number up into two parts,\n", " high and low (H and L):\n", " $(9.1 + 0.011) \\times (9.2 + 0.022)$.\n", " This equals $(H \\times H + H \\times L + L \\times H + L \\times L)$.\n", " The first three matrix\n", " multiplications comprise the three-pass approach and roughly\n", " doubles the effective precision. The fourth term, $L \\times L$, is ignored\n", " and looking at the result, $0.000242$, it is easy to see that\n", " this value will not contribute to the final result. 
" Some values of $L \\times L$\n", " could generate a fourth term that moves the value by one bit, but adding\n", " one bit of information half the time provides little value relative\n", " to the cost of running another multiplication.\n", "\n", " ```text\n", "                   +--------+--------+\n", "                   |      9.222      |\n", "                   +--------+--------+\n", "                   | 9.2    | 0.022  |\n", " +--------+--------+--------+--------+\n", " | 9.111  | 9.1    | 83.72  | 0.2002 |\n", " +--------+--------+--------+--------+\n", " |        | 0.011  | 0.1012 |        |  = 84.0214 => 84.02\n", " +--------+--------+--------+--------+\n", " ```\n", "\n", " Extending this approach again would yield roughly triple the precision.\n", " The idea is to break the number\n", " into high, medium, and low (H, M, and L), generating nine possible terms:\n", " $(H + M + L) \\times (H + M + L) = HH + HM + MH + MM + HL + LH + ML + LM + LL$.\n", " The final three are ignored, and the first six comprise the six-pass approach.\n", " It is essentially equivalent to FP32, with some room for variance in the final bit." ] }, { "cell_type": "markdown", "id": "5cd81784", "metadata": {}, "source": [ "## PyTorch/XLA and TPUs\n", "\n", " PyTorch/XLA allows control of the one-pass, three-pass, and six-pass\n", " approaches through the `torch_xla.backends.set_mat_mul_precision()`\n", " function.\n", " The valid values are `default`, `high`, and `highest`. Now, you'll investigate\n", " the differences between these three settings.\n", "\n", " Warning: Although this notebook demonstrates setting the precision multiple\n", " times,\n", " it is recommended to set the precision only once at the beginning of your\n", " script." ] }, { "cell_type": "markdown", "id": "08e2cb59", "metadata": {}, "source": [ "## Preparations\n", "\n", " Make sure you are running this tutorial on a TPU. You can\n", " access a TPU using Google Colab.\n", "\n", " Import the required packages." ] }, { "cell_type": "code", "execution_count": 1, "id": "da0d842a", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:08:59.346015Z", "iopub.status.busy": "2025-05-15T20:08:59.345651Z", "iopub.status.idle": "2025-05-15T20:09:02.266478Z", "shell.execute_reply": "2025-05-15T20:09:02.265986Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.\n" ] } ], "source": [ "import torch\n", "import torch_xla.backends\n", "\n", "torch.set_printoptions(precision=20, sci_mode=False, linewidth=240)" ] }, { "cell_type": "markdown", "id": "05c0809b", "metadata": {}, "source": [ "Epsilon is the difference between 1.0 and the next representable number above it. Retrieve the value from torch." ] }, { "cell_type": "code", "execution_count": 2, "id": "a3fc145c", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.268113Z", "iopub.status.busy": "2025-05-15T20:09:02.267838Z", "iopub.status.idle": "2025-05-15T20:09:02.270606Z", "shell.execute_reply": "2025-05-15T20:09:02.270204Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bfloat16 epsilon: 0.0078125\n", "return type of torch.finfo: <class 'float'>\n" ] } ], "source": [ "eps = torch.finfo(torch.bfloat16).eps\n", "print(f\"bfloat16 epsilon: {eps}\")\n", "print(f\"return type of torch.finfo: {type(eps)}\")" ] }, { "cell_type": "markdown", "id": "7dff7be6", "metadata": {}, "source": [ "The epsilon is also defined as $1 / 2^p$, where $p$ is the number of bits in the mantissa."
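, "\n", " For example, FP32 has 23 mantissa bits, so its epsilon is $1 / 2^{23}$.\n", " As a quick check (a minimal sketch using only `torch.finfo`, which was\n", " imported above via `torch`):\n", "\n", " ```python\n", " import torch\n", "\n", " # eps equals 2**-p, where p is the mantissa width:\n", " # 7 bits for bfloat16, 23 bits for float32.\n", " assert torch.finfo(torch.bfloat16).eps == 2**-7\n", " assert torch.finfo(torch.float32).eps == 2**-23\n", " ```\n", "\n", " Verify the bfloat16 value by hand:"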
] }, { "cell_type": "code", "execution_count": 3, "id": "cc2b5702", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.271923Z", "iopub.status.busy": "2025-05-15T20:09:02.271691Z", "iopub.status.idle": "2025-05-15T20:09:02.274030Z", "shell.execute_reply": "2025-05-15T20:09:02.273621Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.0078125\n" ] } ], "source": [ "print(1 / (2**7))" ] }, { "cell_type": "markdown", "id": "ef147654", "metadata": {}, "source": [ "Numbers in between may get rounded up to 1.0 + epsilon, or down to 1.0." ] }, { "cell_type": "code", "execution_count": 4, "id": "80595534", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.275297Z", "iopub.status.busy": "2025-05-15T20:09:02.275092Z", "iopub.status.idle": "2025-05-15T20:09:02.278882Z", "shell.execute_reply": "2025-05-15T20:09:02.278473Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([1.00000000000000000000, 1.00000000000000000000, 1.00000000000000000000, 1.00781250000000000000, 1.00781250000000000000], dtype=torch.bfloat16)\n" ] } ], "source": [ "print(\n", " torch.tensor(\n", " [1.0, 1.0 + eps / 4.0, 1.0 + eps / 2, 1.0 + eps * 3 / 4, 1.0 + eps],\n", " dtype=torch.bfloat16,\n", " ))" ] }, { "cell_type": "markdown", "id": "a54979f0", "metadata": {}, "source": [ "## Get ready to look directly at bits\n", "\n", " Set up tools to convert binary strings to FP32 numbers, and vice\n", " versa. Create a function to generate a random matrix.\n", "\n", " In general, when\n", " testing an MXU (or TensorCore), pass matrices to encourage XLA to\n", " use the MXU rather than the slower but more precise units for FP32." ] }, { "cell_type": "code", "execution_count": 5, "id": "5e8d9050", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.280122Z", "iopub.status.busy": "2025-05-15T20:09:02.279908Z", "iopub.status.idle": "2025-05-15T20:09:02.283995Z", "shell.execute_reply": "2025-05-15T20:09:02.283591Z" } }, "outputs": [], "source": [ "import struct\n", "\n", "\n", "def binary_fraction_to_fp32(bstr: str) -> float:\n", " if bstr[:4] != \"0b1.\":\n", " raise ValueError(f\"Invalid binary string: {bstr}\")\n", " fraction_bits = bstr[4:]\n", " mantissa = 1.0\n", " for i, bit in enumerate(fraction_bits):\n", " mantissa += int(bit) * 2**-(i + 1)\n", " return float(mantissa)\n", "\n", "\n", "def fp32_to_binary_fraction(fp32_float: float) -> str:\n", " x_bytes = struct.pack(\">f\", fp32_float) # Big-endian IEEE 754 float32\n", " as_int = struct.unpack(\">I\", x_bytes)[0] # Interpret bits as uint32\n", " sign = (as_int >> 31) & 0b1\n", " exponent = (as_int >> 23) & 0xFF\n", " mantissa = as_int & 0x7FFFFF # lower 23 bits\n", " return f\"FORMAT:0b SIGN:{sign} EXPONENT:{exponent:08b} MANTISSA:{mantissa:023b} VALUE={fp32_float}\"\n", "\n", "\n", "def get_rand_matrix():\n", " \"\"\"Returns a diagonal matrix of shape 1024, 1024, values between 0.999 and 1.111\"\"\"\n", " eye = torch.eye(1024, dtype=torch.float32, device=\"xla\")\n", " rand_ = torch.rand(\n", " (1024, 1024), dtype=torch.float32, device=\"xla\") * 0.2 + 0.9\n", " result = eye * rand_\n", " assert torch.nonzero(result).size(0) == 1024, torch.nonzero(result).size(0)\n", " return result" ] }, { "cell_type": "markdown", "id": "876974bb", "metadata": {}, "source": [ "## Examining a number\n", "\n", " Generate an FP32 number representing 1 + bf16_eps/2. This\n", " will put one extra bit out of reach of a bfloat16's mantissa." 
] }, { "cell_type": "code", "execution_count": 6, "id": "aea31981", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.285188Z", "iopub.status.busy": "2025-05-15T20:09:02.285017Z", "iopub.status.idle": "2025-05-15T20:09:02.287520Z", "shell.execute_reply": "2025-05-15T20:09:02.287126Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FP32 : 1.00390625\n", "1 + eps/2: 1.00390625\n" ] } ], "source": [ "one_plus_half_eps = binary_fraction_to_fp32(\"0b1.\" + \"0\" * 7 + \"1\" + \"0\" * 15)\n", "print(f\"FP32 : {one_plus_half_eps }\")\n", "print(f\"1 + eps/2: {1.0 + eps / 2}\")" ] }, { "cell_type": "markdown", "id": "4aab8ded", "metadata": {}, "source": [ "Print the bits for FP32 and BF16. Notice that the 8th bit is lost.\n", " This reconfirms that BF16 cannot represent the 8th bit of precision." ] }, { "cell_type": "code", "execution_count": 7, "id": "fc626b20", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.288805Z", "iopub.status.busy": "2025-05-15T20:09:02.288596Z", "iopub.status.idle": "2025-05-15T20:09:02.291240Z", "shell.execute_reply": "2025-05-15T20:09:02.290846Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "FP32: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000001000000000000000 VALUE=1.00390625\n", "BF16: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000000000000000000000 VALUE=1.0\n" ] } ], "source": [ "print(f\"FP32: {fp32_to_binary_fraction(one_plus_half_eps)}\")\n", "ones_bf16 = torch.tensor(\n", " one_plus_half_eps, dtype=torch.bfloat16).to(torch.float32).item()\n", "print(f\"BF16: {fp32_to_binary_fraction(ones_bf16)}\")" ] }, { "cell_type": "markdown", "id": "2eff9e6f", "metadata": {}, "source": [ "## MXU\n", "\n", " Place your numbers of interest in a diagonal matrix. By putting them in\n", " a matrix, XLA will execute the math on the MXU. By making the matrices\n", " diagonal, the math will be equivalent to element-wise multiplication.\n", "\n", " Notice that the values are essentially rounded down to 1.0 before\n", " being multiplied, resulting in 1.0 as the output.\n", " This is the loss of precision that occurs in a TPU." ] }, { "cell_type": "code", "execution_count": 8, "id": "ce0df76e", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:02.292448Z", "iopub.status.busy": "2025-05-15T20:09:02.292247Z", "iopub.status.idle": "2025-05-15T20:09:17.401029Z", "shell.execute_reply": "2025-05-15T20:09:17.400357Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000001000000000000000 VALUE=1.00390625\n", "Y: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000001000000000000000 VALUE=1.00390625\n", "Z: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000000000000000000000 VALUE=1.0\n" ] } ], "source": [ "X = get_rand_matrix()\n", "Y = get_rand_matrix()\n", "X[0, 0] = one_plus_half_eps\n", "Y[0, 0] = one_plus_half_eps\n", "Z = torch.matmul(X, Y)\n", "print(f\"X: {fp32_to_binary_fraction(X[0][0].item())}\")\n", "print(f\"Y: {fp32_to_binary_fraction(Y[0][0].item())}\")\n", "print(f\"Z: {fp32_to_binary_fraction(Z[0][0].item())}\")" ] }, { "cell_type": "markdown", "id": "598de4b7", "metadata": {}, "source": [ "## FP32 precision on bfloat16 hardware\n", "\n", " The 3 and 6 pass approaches generate more bits of precision.\n", " Turn on the highest precision mode (six passes) and run\n", " the experiment again. Notice that the TPU has calculated FP32 precision." 
] }, { "cell_type": "code", "execution_count": 9, "id": "ebfa965b", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:17.402608Z", "iopub.status.busy": "2025-05-15T20:09:17.402427Z", "iopub.status.idle": "2025-05-15T20:09:17.772271Z", "shell.execute_reply": "2025-05-15T20:09:17.771603Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:torch_xla.backends:Setting mat mul precision multiple times is not recommended. If you need to do so, please empirically verify that the precision setting is behaving as expected.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Z_ref: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000010000000010000000 VALUE=1.0078277587890625\n", "Z: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:00000010000000010000000 VALUE=1.0078277587890625\n" ] } ], "source": [ "Z_ref = torch.matmul(\n", " X.to(\"cpu\").to(torch.float32),\n", " Y.to(\"cpu\").to(torch.float32))\n", "print(f\"Z_ref: {fp32_to_binary_fraction(Z_ref[0][0].item())}\")\n", "torch_xla.backends.set_mat_mul_precision(\"highest\")\n", "Z = torch.matmul(X, Y)\n", "print(f\"Z: {fp32_to_binary_fraction(Z[0][0].item())}\")" ] }, { "cell_type": "markdown", "id": "d14d7dfe", "metadata": {}, "source": [ "## Edge-case numbers\n", "\n", " In the previous example, you saw no difference between\n", " the six-pass and FP32 multiplication. Now, you will use an edge\n", " case number to demonstrate a difference in the\n", " final bit between the six-pass approach and full FP32." ] }, { "cell_type": "code", "execution_count": 10, "id": "8011e785", "metadata": { "execution": { "iopub.execute_input": "2025-05-15T20:09:17.773780Z", "iopub.status.busy": "2025-05-15T20:09:17.773587Z", "iopub.status.idle": "2025-05-15T20:09:18.408244Z", "shell.execute_reply": "2025-05-15T20:09:18.407613Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:torch_xla.backends:Setting mat mul precision multiple times is not recommended. If you need to do so, please empirically verify that the precision setting is behaving as expected.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Z_ref: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:01110000101000111101100 VALUE=1.440000057220459\n", "Z: FORMAT:0b SIGN:0 EXPONENT:01111111 MANTISSA:01110000101000111101101 VALUE=1.4400001764297485\n" ] } ], "source": [ "X = get_rand_matrix()\n", "Y = get_rand_matrix()\n", "X[0, 0] = 1.2\n", "Y[0, 0] = 1.2\n", "Z_ref = torch.matmul(\n", " X.to(\"cpu\").to(torch.float32),\n", " Y.to(\"cpu\").to(torch.float32))\n", "print(f\"Z_ref: {fp32_to_binary_fraction(Z_ref[0][0].item())}\")\n", "torch_xla.backends.set_mat_mul_precision(\"highest\")\n", "Z = torch.matmul(X, Y)\n", "print(f\"Z: {fp32_to_binary_fraction(Z[0][0].item())}\")" ] }, { "cell_type": "markdown", "id": "ab0a2c9a", "metadata": {}, "source": [ "## Conclusion\n", "\n", " In this tutorial, you learned how to control the floating point\n", " precision of your matrix multiplication (mat mul) operations.\n", " You also learned the internal algorithm used to generate\n", " higher precision through the three-pass and six-pass approaches." ] } ], "metadata": { "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }